#include <iostream>

using namespace std;

/// Node of a singly linked list of integers.
struct ListNode {
    int val;         ///< Value stored in this node.
    ListNode *next;  ///< Following node, or nullptr when this is the last node.

    /// Creates a detached node holding @p x; `next` starts out as nullptr.
    /// Left non-explicit on purpose so existing implicit int -> ListNode
    /// conversions in callers keep compiling.
    ListNode(int x) : val(x), next(nullptr) {}
};

class Solution {
public:
    /// Unlinks every node whose value equals @p val from the list.
    ///
    /// @param head First node of the list; may be nullptr for an empty list.
    /// @param val  Value to remove.
    /// @return The first remaining node, or nullptr if nothing is left.
    ///
    /// Unlinked nodes are only bypassed: they are not deleted and their own
    /// `next` pointers are left as they were.
    ListNode* removeElements(ListNode* head, int val) {
        // `link` always addresses the pointer that leads to the node under
        // inspection: first the local `head`, later some kept node's `next`.
        // Rewriting through it handles head and interior removals uniformly.
        ListNode** link = &head;
        while (*link != nullptr) {
            ListNode* current = *link;
            if (current->val != val) {
                // Keep this node; step to the link that hangs off it.
                link = &current->next;
                continue;
            }
            // Bypass the matching node. `link` stays put so the successor
            // that just moved into this slot is examined on the next pass.
            *link = current->next;
        }
        return head;
    }
};